import argparse
import copy
import dataclasses
import json
import os
from enum import IntFlag, auto
from functools import cached_property
from typing import Dict, List, Optional, Union

import numpy as np
import safetensors
import torch

from .._common import default_net
from .._utils import (get_init_params, numpy_to_torch, release_gc,
                      str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch)
from ..functional import PositionEmbeddingType, Tensor, gather_last_token_logits
from ..layers import (AttentionParams, Embedding, FusedGatedMLP, FusedRgLru,
                      GatedMLP, KeyValueCacheParams, LoraParams,
                      PromptTuningEmbedding, RgLru)
from ..layers.attention import Attention, BertAttention
from ..layers.linear import ColumnLinear, Linear, RowLinear
from ..layers.lora import Lora
from ..layers.moe import MOE, MoeOOTB
from ..logger import logger
from ..mapping import Mapping
from ..module import Module, ModuleList
from ..parameter import Parameter
from ..quantization import QuantMode
from ..quantization.mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo
from ..top_model_mixin import TopModelMixin
from .convert_utils import weight_only_quantize_dict
from .generation_mixin import GenerationMixin

# Architectures whose checkpoints go through per-parameter ``weight_loader``
# hooks: ``PretrainedModel.from_checkpoint`` reads ``rank0.safetensors`` for
# every rank, and ``PretrainedModel.load`` passes each tensor (together with
# the config's mapping) to the parameter's loader.
WEIGHT_LOADER_MODELS = {
    "PhiForCausalLM",
}


class SpeculativeDecodingMode(IntFlag):
    """Speculative decoding strategy selected for an engine.

    Every member is a distinct bit produced by ``auto()``, so ``NONE`` is
    ``1`` rather than ``0``.
    """
    # [WARNING] KEEP BELOW DEFINITION IN SYNC WITH cpp/tensorrt_llm/runtime/speculativeDecodingMode.h
    NONE = auto()
    DRAFT_TOKENS_EXTERNAL = auto()
    MEDUSA = auto()
    LOOKAHEAD_DECODING = auto()
    EXPLICIT_DRAFT_TOKENS = auto()

    @staticmethod
    def from_arguments(args: argparse.Namespace):
        """Map ``args.speculative_decoding_mode`` to a member of this enum.

        Args:
            args: parsed CLI namespace with a ``speculative_decoding_mode``
                attribute that is ``None`` or one of ``"draft_tokens_external"``,
                ``"medusa"``, ``"lookahead_decoding"``, ``"explicit_draft_tokens"``.

        Returns:
            The matching ``SpeculativeDecodingMode``; ``NONE`` when the
            attribute is ``None``.

        Raises:
            ValueError: for any other value. An explicit raise (rather than
                ``assert``) keeps the check active under ``python -O``, where
                the method used to return ``None`` silently.
        """
        mode = args.speculative_decoding_mode
        if mode is None:
            return SpeculativeDecodingMode.NONE
        # Built inside the method on purpose: a dict assigned in an Enum class
        # body would itself become an enum member.
        name_to_mode = {
            "draft_tokens_external":
            SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL,
            "medusa": SpeculativeDecodingMode.MEDUSA,
            "lookahead_decoding": SpeculativeDecodingMode.LOOKAHEAD_DECODING,
            "explicit_draft_tokens":
            SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS,
        }
        if mode not in name_to_mode:
            raise ValueError(f"Unknown speculative_decoding_mode {mode}")
        return name_to_mode[mode]


@dataclasses.dataclass
class QuantConfig:
    '''Serializable quantization configuration class, part of the PretrainedConfig

    ``quant_mode`` is derived from ``quant_algo`` and ``kv_cache_quant_algo``
    and cached on first access. The cache is discarded whenever either of
    those two fields is reassigned, so the derived mode cannot go stale on
    this mutable dataclass.
    '''

    quant_algo: Optional[QuantAlgo] = None
    kv_cache_quant_algo: Optional[QuantAlgo] = None
    group_size: Optional[int] = 128
    smoothquant_val: Optional[float] = None
    has_zero_point: Optional[bool] = False
    pre_quant_scale: Optional[bool] = False
    exclude_modules: Optional[List[str]] = None

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        # ``quant_mode`` is a cached_property, i.e. stored in the instance
        # ``__dict__`` under its own name; drop it when one of its inputs
        # changes so the next access recomputes it.
        if name in ('quant_algo', 'kv_cache_quant_algo'):
            self.__dict__.pop('quant_mode', None)

    @property
    def use_plugin_sq(self):
        """True when ``quant_algo`` is one of ``W8A8_SQ_PLUGIN_LIST``."""
        return self.quant_algo in W8A8_SQ_PLUGIN_LIST

    @cached_property
    def quant_mode(self) -> QuantMode:
        """``QuantMode`` built from ``quant_algo`` and ``kv_cache_quant_algo``."""
        return QuantMode.from_quant_algo(
            self.quant_algo,
            self.kv_cache_quant_algo,
        )

    def quant_algo_to_modelopt_qformat(self):
        """Return the ModelOpt ``qformat`` string for ``quant_algo``.

        Returns ``'full_prec'`` when ``quant_algo`` is ``None``; asserts that
        any other algorithm is one of the mapped ones.
        """
        algo_to_modelopt_map = {
            QuantAlgo.W8A16: "int8_wo",
            QuantAlgo.W4A16: "int4_wo",
            QuantAlgo.W4A16_AWQ: "int4_awq",
            QuantAlgo.W4A8_AWQ: 'w4a8_awq',
            QuantAlgo.FP8: 'fp8',
            QuantAlgo.W8A8_SQ_PER_CHANNEL: 'int8_sq',
        }
        if self.quant_algo is not None:
            assert self.quant_algo in algo_to_modelopt_map, f"We don't use Modelopt for quantization algorithm {self.quant_algo}, you probably shall not call this"
            qformat = algo_to_modelopt_map[self.quant_algo]
        else:
            qformat = 'full_prec'
        return qformat

    @classmethod
    def from_dict(cls, config: dict):
        """Construct from a dict whose keys are exactly the dataclass fields."""
        return cls(**config)

    def to_dict(self):
        """Return the dataclass fields as a plain dict (``dataclasses.asdict``)."""
        return dataclasses.asdict(self)


def default_weight_loader(mapping: Mapping, param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Store ``loaded_weight`` unchanged into ``param.value``.

    ``mapping`` is accepted only so this function matches the signature of
    per-parameter ``weight_loader`` hooks; it is not used here.
    """
    setattr(param, 'value', loaded_weight)


def save_checkpoint(output_dir: str, config: dict, weights: dict) -> None:
    """Write a single-file checkpoint for the weight-loader path.

    Creates ``config.json`` (JSON, indent 4) and then ``rank0.safetensors``
    inside ``output_dir``. The directory itself is not created here.
    """
    config_path = os.path.join(output_dir, 'config.json')
    weights_path = os.path.join(output_dir, 'rank0.safetensors')
    with open(config_path, 'w') as config_file:
        json.dump(config, config_file, indent=4)
    safetensors.torch.save_file(weights, weights_path)


class PretrainedConfig:
    """Serializable model configuration shared by TensorRT-LLM models.

    Core hyper-parameters are keyword-only arguments. Any extra keyword
    argument is stored verbatim as an attribute (with a warning), so unknown
    fields survive a ``to_dict``/``from_dict`` round trip. ``mapping`` and
    ``quantization`` accept either their object form or the dict form
    produced by ``to_dict``.
    """

    def __init__(self,
                 *,
                 architecture: str,
                 dtype: str,
                 hidden_size: int,
                 num_hidden_layers: int,
                 num_attention_heads: int,
                 vocab_size: Optional[int] = None,
                 hidden_act: str = 'gelu',
                 logits_dtype: str = 'float32',
                 norm_epsilon: float = 1e-5,
                 position_embedding_type: Union[
                     PositionEmbeddingType,
                     str] = PositionEmbeddingType.learned_absolute,
                 max_position_embeddings: Optional[int] = None,
                 num_key_value_heads: Optional[int] = None,
                 intermediate_size: Optional[int] = None,
                 mapping: Optional[Union[Mapping, dict]] = None,
                 quantization: Optional[Union[QuantConfig, dict]] = None,
                 use_parallel_embedding: bool = False,
                 embedding_sharding_dim: int = 0,
                 share_embedding_table: bool = False,
                 head_size: Optional[int] = None,
                 qk_layernorm: bool = False,
                 **kwargs):
        """Populate the config.

        Defaults derived from other arguments:
            num_key_value_heads -> num_attention_heads
            intermediate_size   -> 4 * hidden_size
            head_size           -> hidden_size // num_attention_heads
            mapping             -> Mapping()
            quantization        -> QuantConfig()

        Raises:
            NotImplementedError: if ``share_embedding_table`` is requested with
                tensor parallelism without ``use_parallel_embedding=True`` and
                ``embedding_sharding_dim=0``, or with pipeline parallelism.
            AttributeError: if an extra keyword argument cannot be set (for
                example one named like the read-only ``quant_mode`` or
                ``kv_dtype`` properties).
        """
        self.architecture = architecture
        self.dtype = dtype
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act

        self.logits_dtype = logits_dtype
        self.norm_epsilon = norm_epsilon

        if isinstance(position_embedding_type, str):
            position_embedding_type = PositionEmbeddingType.from_string(
                position_embedding_type)
        assert isinstance(position_embedding_type, PositionEmbeddingType)
        self.position_embedding_type = position_embedding_type

        self.max_position_embeddings = max_position_embeddings

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        if intermediate_size is None:
            intermediate_size = hidden_size * 4
        self.intermediate_size = intermediate_size

        if mapping is None:
            mapping = Mapping()
        elif isinstance(mapping, dict):
            mapping = Mapping.from_dict(mapping)
        assert isinstance(mapping, Mapping)
        self.mapping = mapping

        if quantization is None:
            quantization = QuantConfig()
        elif isinstance(quantization, dict):
            quantization = QuantConfig.from_dict(quantization)
        assert isinstance(quantization, QuantConfig)
        self.quantization = quantization

        self.use_parallel_embedding = use_parallel_embedding
        self.embedding_sharding_dim = embedding_sharding_dim
        self.share_embedding_table = share_embedding_table

        if share_embedding_table and mapping.tp_size > 1:
            # Only use_parallel_embedding=True with embedding_sharding_dim=0
            # is supported for a shared table under tensor parallelism.
            if not use_parallel_embedding or embedding_sharding_dim == 1:
                raise NotImplementedError(
                    "For tensor parallelism, sharing the embedding table must set "
                    "use_parallel_embedding=True and embedding_sharding_dim=0")
        if share_embedding_table and mapping.pp_size > 1:
            raise NotImplementedError(
                "Embedding table cannot be shared for pipeline parallelism")

        if head_size is None:
            head_size = hidden_size // num_attention_heads
        self.head_size = head_size
        self.qk_layernorm = qk_layernorm

        # Remaining keyword arguments become plain attributes. Any
        # AttributeError raised by setattr propagates to the caller.
        for key, value in kwargs.items():
            setattr(self, key, value)
            logger.warning(
                f"Implicitly setting {self.__class__.__name__}.{key} = {value}"
            )

    @property
    def kv_dtype(self):
        """KV-cache dtype string: ``'int8'``/``'fp8'`` if quantized, else ``dtype``."""
        if self.quant_mode.has_int8_kv_cache():
            return 'int8'
        elif self.quant_mode.has_fp8_kv_cache():
            return 'fp8'
        else:
            return self.dtype

    def set_if_not_exist(self, key, value):
        """Set attribute ``key`` to ``value`` only if it is not already present."""
        if not hasattr(self, key):
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config: dict):
        """Build a config, using the architecture's ``config_class`` if it has one.

        The model class is looked up in ``MODEL_MAP`` by ``config['architecture']``;
        ``cls`` is used when that model class defines no ``config_class``.
        """
        # Maybe we need AutoConfig for this
        from . import MODEL_MAP
        model_cls = MODEL_MAP[config['architecture']]
        config_cls = getattr(model_cls, 'config_class', cls)
        return config_cls(**config)

    def to_dict(self):
        """Return a JSON-serializable deep copy of the config.

        ``position_embedding_type`` becomes a string, ``mapping`` a dict
        without its ``rank`` entry, and ``quantization`` a dict.
        """
        output = copy.deepcopy(self.__dict__)

        output['position_embedding_type'] = str(self.position_embedding_type)
        output['mapping'] = self.mapping.to_dict()
        output['mapping'].pop('rank')
        output['quantization'] = self.quantization.to_dict()

        return output

    @classmethod
    def from_json_file(cls, config_file: str):
        """Load a JSON file and build the config via ``from_dict``."""
        with open(config_file) as f:
            config = json.load(f)
        return cls.from_dict(config)

    @classmethod
    def from_checkpoint(cls, ckpt_dir: str):
        """Load ``<ckpt_dir>/config.json``."""
        return cls.from_json_file(os.path.join(ckpt_dir, 'config.json'))

    def to_json_file(self, config_file: str):
        """Write ``to_dict()`` as JSON (indent 4) to ``config_file``."""
        with open(config_file, 'w') as f:
            json.dump(self.to_dict(), f, indent=4)

    @property
    def quant_mode(self):
        """Shortcut for ``self.quantization.quant_mode``."""
        return self.quantization.quant_mode

    def set_rank(self, rank):
        """Replace ``mapping`` with a copy whose rank is ``rank``.

        Only ``world_size``, ``tp_size``, ``pp_size`` and ``gpus_per_node`` are
        carried over from the current mapping.
        """
        self.mapping = Mapping(self.mapping.world_size,
                               rank=rank,
                               tp_size=self.mapping.tp_size,
                               pp_size=self.mapping.pp_size,
                               gpus_per_node=self.mapping.gpus_per_node)


class DecoderLayerList(ModuleList):
    """List of decoder layers built from ``config.mapping.pp_layers``.

    ``layer_list`` holds the layer indices returned by
    ``config.mapping.pp_layers(config.num_hidden_layers)`` (presumably the
    layers assigned to the current pipeline stage). Each layer is constructed
    as ``cls(config, idx)`` with that index. Indexing into this list, however,
    is positional and 0-based.
    """

    def __init__(self, cls, config):
        self.num_hidden_layers = config.num_hidden_layers
        self.layer_list = config.mapping.pp_layers(config.num_hidden_layers)
        super().__init__([cls(config, idx) for idx in self.layer_list])

    def forward(self,
                hidden_states,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                position_ids=None,
                lora_params=None,
                spec_decoding_params=None):
        """Run ``hidden_states`` through every layer in order.

        ``kv_cache_params`` is required: ``fill_none_tensor_list`` is called
        on it unconditionally, and each layer receives a per-layer
        ``KeyValueCacheParams`` holding its own ``past_key_value`` entry.

        Returns:
            ``hidden_states``, or ``(hidden_states, presents)`` when
            ``use_cache`` is true (one present per layer, in order).
        """
        num_local_layers = len(self.layer_list)
        kv_cache_params.fill_none_tensor_list(num_local_layers)

        if use_cache:
            presents = []

        for layer_idx, (layer, past) in enumerate(
                zip(self, kv_cache_params.past_key_value)):

            lora_layer_params = None
            if lora_params is not None and lora_params.lora_ranks is not None:
                lora_layer_params = lora_params.get_layer_params(layer_idx)

            kwargs = {}
            if position_ids is not None:
                kwargs['position_ids'] = position_ids
            if lora_layer_params is not None:
                kwargs['lora_layer_params'] = lora_layer_params
            if spec_decoding_params is not None:
                kwargs['spec_decoding_params'] = spec_decoding_params
            if default_net().plugin_config.reduce_fusion:
                # layer_idx is the local position in this list, so compare it
                # with the local layer count. Comparing with layer_list[-1] (a
                # pp_layers index) let the last local layer of a later pipeline
                # stage index self[layer_idx + 1] out of range.
                if layer_idx < num_local_layers - 1:
                    kwargs['next_layer_input_layernorm_args'] = (
                        self[layer_idx + 1].input_layernorm.weight.value,
                        self[layer_idx + 1].input_layernorm.eps)
                else:
                    kwargs['next_layer_input_layernorm_args'] = None

            hidden_states = layer(
                hidden_states,
                use_cache=use_cache,
                attention_mask=attention_mask,
                kv_cache_params=KeyValueCacheParams(
                    past_key_value=[past],
                    host_past_key_value_lengths=kv_cache_params.
                    host_past_key_value_lengths,
                    host_max_attention_window_sizes=kv_cache_params.
                    host_max_attention_window_sizes,
                    host_sink_token_length=kv_cache_params.
                    host_sink_token_length,
                    kv_cache_block_offsets=kv_cache_params.
                    kv_cache_block_offsets,
                    host_kv_cache_block_offsets=kv_cache_params.
                    host_kv_cache_block_offsets,
                    host_kv_cache_pool_pointers=kv_cache_params.
                    host_kv_cache_pool_pointers,
                    cache_indirection=kv_cache_params.cache_indirection),
                attention_params=attention_params,
                **kwargs)

            if use_cache:
                presents.append(hidden_states[1])
                hidden_states = hidden_states[0]

        if use_cache:
            return hidden_states, presents
        return hidden_states


class PostInitCaller(type):
    """Metaclass that invokes ``__post_init__`` once construction finishes.

    Calling a class that uses this metaclass builds the instance normally
    (``__new__`` + ``__init__``), then calls the instance's ``__post_init__``
    and returns the instance.
    """

    def __call__(cls, *args, **kwargs):
        instance = type.__call__(cls, *args, **kwargs)
        post_init_hook = instance.__post_init__
        post_init_hook()
        return instance


class PretrainedModel(Module,
                      GenerationMixin,
                      TopModelMixin,
                      metaclass=PostInitCaller):
    """Base class for TensorRT-LLM models constructed from a ``PretrainedConfig``.

    Through the ``PostInitCaller`` metaclass, ``__post_init__`` runs after the
    subclass ``__init__`` returns. It applies quantization and embedding
    optimizations before any weights are loaded.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config

    def __post_init__(self):
        """Quantize the module tree and apply embedding-related optimizations."""
        from ..quantization.quantize import quantize
        quantize(self, self.config.quantization)

        # Currently, use_parallel_embedding and share_embedding_table must be enabled before weight loading;
        # otherwise, the model will be inconsistent with the weights loaded from checkpoint.
        optimize_model(
            self,
            use_parallel_embedding=self.config.use_parallel_embedding,
            share_embedding_table=self.config.share_embedding_table,
        )

    def release(self):
        """Call ``release_gc()``."""
        release_gc()

    def __del__(self):
        self.release()

    def check_config(self, config):
        """Abstract hook; subclasses must override."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        """Construct the model from ``config`` without loading weights."""
        return cls(config)

    @classmethod
    def from_checkpoint(cls,
                        ckpt_dir: str,
                        rank: Optional[int] = None,
                        config: Optional[PretrainedConfig] = None):
        """Build a model and load its weights from a TensorRT-LLM checkpoint.

        Args:
            ckpt_dir: directory holding ``config.json`` and
                ``rank{N}.safetensors`` files.
            rank: if given, overrides the rank stored in the config's mapping.
            config: pre-loaded config; read from ``<ckpt_dir>/config.json``
                when omitted.

        Returns:
            The constructed model with the checkpoint weights loaded.

        Raises:
            FileNotFoundError: if the expected safetensors file is missing.
        """
        if config is None:
            config = PretrainedConfig.from_json_file(
                os.path.join(ckpt_dir, 'config.json'))

        if rank is not None:
            config.set_rank(rank)

        model = cls(config)
        # Weight-loader architectures read rank0.safetensors on every rank and
        # pass the mapping to each parameter's weight_loader (see ``load``);
        # all others read their own rank{N}.safetensors.
        if config.architecture in WEIGHT_LOADER_MODELS:
            weights_path = os.path.join(ckpt_dir, 'rank0.safetensors')
        else:
            rank = config.mapping.rank
            weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors')

        # Explicit check rather than ``assert`` so it still runs under -O.
        if not os.path.isfile(weights_path):
            raise FileNotFoundError(
                f"Checkpoint weights file not found: {weights_path}")
        weights = safetensors.torch.load_file(weights_path)

        is_checkpoint_pruned = getattr(config, 'is_pruned', False)
        preprocess_weights(weights, config, from_pruned=is_checkpoint_pruned)
        model.load(weights, from_pruned=is_checkpoint_pruned)

        return model

    def load(self, weights, from_pruned=False):
        """Assign checkpoint tensors to this model's parameters.

        Args:
            weights: mapping from parameter name to tensor.
            from_pruned: when true, a tensor whose shape does not match the
                parameter is replaced by a dummy (weight-loader path) or
                handed to ``set_value_or_dummy`` (default path).

        Raises:
            RuntimeError: if a parameter that is not already initialized has
                no tensor in ``weights``, or if assigning a tensor fails.

        Tensors in ``weights`` that match no parameter only produce a warning.
        """
        expected_names = set()
        required_names = set()
        for name, param in self.named_parameters():
            expected_names.add(name)
            if not param.is_inited():
                required_names.add(name)

        provided_names = set(weights.keys())
        if not required_names.issubset(provided_names):
            raise RuntimeError(
                f"Required but not provided tensors:{required_names.difference(provided_names)}"
            )
        if not provided_names.issubset(expected_names):
            logger.warning(
                f"Provided but not expected tensors: {provided_names.difference(expected_names)}"
            )

        if self.config.architecture in WEIGHT_LOADER_MODELS:
            mapping = self.config.mapping
            for name, param in self.named_parameters():
                if name in provided_names:
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    if from_pruned and param._shape != weights[name].shape:
                        dummy_weight = torch.empty(param._shape,
                                                   dtype=trt_dtype_to_torch(
                                                       param._dtype))
                        weight_loader(mapping, param, dummy_weight)
                    else:
                        weight_loader(mapping, param, weights[name])
        else:
            for name, param in self.named_parameters():
                if name in provided_names:
                    if not from_pruned:
                        try:
                            param.value = weights[name]
                        except Exception as e:
                            raise RuntimeError(
                                f"Encounter error '{e}' for parameter '{name}'")
                    else:
                        param.set_value_or_dummy(weights[name])

    def load_partial_weights(self, weights: dict):
        """Load a subset of parameters through their ``weight_loader`` hooks.

        Unknown names are skipped; they are logged as a warning only when
        ``pp_size == 1``.
        """
        params = {name: param for name, param in self.named_parameters()}
        mapping = self.config.mapping

        for k, v in weights.items():
            if k in params:
                param = params[k]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(mapping, param, v)
            elif mapping.pp_size == 1:
                logger.warning(f"Provided but not expected tensors: {k}")

    def save_checkpoint(self, output_dir, save_config=True):
        """Write this rank's weights (and optionally ``config.json``) to ``output_dir``."""
        # multiple ranks could share same config.json, so adding a save_config parameter to let user avoiding writing config.json in all ranks
        rank = self.config.mapping.rank
        weights = {
            name: numpy_to_torch(param.raw_value)
            for name, param in self.named_parameters()
        }
        safetensors.torch.save_file(
            weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
        if save_config:
            self.config.to_json_file(os.path.join(output_dir, 'config.json'))

    def prepare_inputs(self,
                       max_batch_size,
                       max_input_len,
                       max_seq_len,
                       max_num_tokens,
                       use_cache,
                       max_beam_width: int = 1,
                       opt_num_tokens: int = None,
                       prompt_embedding_table_size: int = 0,
                       position_encoding_2d: bool = False,
                       max_draft_len: int = 0,
                       speculative_decoding_draft_tokens_external: bool = False,
                       gather_context_logits: bool = False,
                       gather_generation_logits: bool = False,
                       lora_target_modules: List[str] = None,
                       opt_batch_size: int = 0):
        '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the
            ranges of the dimensions of when using TRT dynamic shapes.

            @return: a list contains values which can be fed into the self.forward()
        '''

        # Prepare inputs
        remove_input_padding = default_net().plugin_config.remove_input_padding
        use_gpt_attention_plugin = default_net(
        ).plugin_config.gpt_attention_plugin
        use_gemm_plugin = default_net().plugin_config.gemm_plugin
        paged_kv_cache = default_net().plugin_config.paged_kv_cache
        tokens_per_block = default_net().plugin_config.tokens_per_block
        use_custom_all_reduce = default_net(
        ).plugin_config.use_custom_all_reduce
        use_lora_plugin = default_net().plugin_config.lora_plugin
        multiple_profiles = default_net().plugin_config.multiple_profiles
        streamingllm = default_net().plugin_config.streamingllm

        model_inputs = self.prepare_basic_inputs(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_input_len=max_input_len,
            max_seq_len=max_seq_len,
            hidden_size=self.config.hidden_size,
            num_kv_heads=self.config.num_key_value_heads,
            head_size=self.config.head_size,
            num_layers=self.config.num_hidden_layers,
            kv_dtype=str_dtype_to_trt(self.config.kv_dtype),
            remove_input_padding=remove_input_padding,
            use_gpt_attention_plugin=use_gpt_attention_plugin,
            use_gemm_plugin=use_gemm_plugin,
            paged_kv_cache=paged_kv_cache,
            tokens_per_block=tokens_per_block,
            num_heads=self.config.num_attention_heads,
            max_num_tokens=max_num_tokens,
            opt_num_tokens=opt_num_tokens,
            dtype=str_dtype_to_trt(self.config.dtype),
            prompt_embedding_table_size=prompt_embedding_table_size,
            position_encoding_2d=position_encoding_2d,
            mapping=self.config.mapping,
            gather_context_logits=gather_context_logits,
            gather_generation_logits=gather_generation_logits,
            use_custom_all_reduce=use_custom_all_reduce,
            use_lora_plugin=use_lora_plugin,
            max_draft_len=max_draft_len,
            speculative_decoding_draft_tokens_external=
            speculative_decoding_draft_tokens_external,
            lora_target_modules=lora_target_modules,
            multiple_profiles=multiple_profiles,
            streamingllm=streamingllm,
            opt_batch_size=opt_batch_size)

        # NOTE(review): 'use_cache' is always True here regardless of the
        # use_cache argument — confirm this is intended.
        result = {
            'input_ids':
            model_inputs['input_ids'],
            'position_ids':
            model_inputs['position_ids'],
            'use_cache':
            True,
            'last_token_ids':
            model_inputs['last_token_ids'],
            'attention_mask':
            model_inputs['attention_mask'],
            'kv_cache_params':
            KeyValueCacheParams(
                past_key_value=model_inputs['past_key_value'],
                host_past_key_value_lengths=model_inputs[
                    'host_past_key_value_lengths'],
                host_max_attention_window_sizes=model_inputs[
                    'host_max_attention_window_sizes'],
                host_sink_token_length=model_inputs['host_sink_token_length'],
                kv_cache_block_offsets=model_inputs['kv_cache_block_offsets'],
                host_kv_cache_block_offsets=model_inputs[
                    'host_kv_cache_block_offsets'],
                host_kv_cache_pool_pointers=model_inputs[
                    'host_kv_cache_pool_pointers'],
                cache_indirection=model_inputs['cache_indirection'],
            ),
            'attention_params':
            AttentionParams(
                sequence_length=model_inputs['sequence_length'],
                context_lengths=model_inputs['context_lengths'],
                host_context_lengths=model_inputs['host_context_lengths'],
                max_context_length=max_input_len,
                host_request_types=model_inputs['host_request_types'])
        }

        if prompt_embedding_table_size > 0:
            result['prompt_embedding_table'] = model_inputs[
                'prompt_embedding_table']
            result['prompt_tasks'] = model_inputs['tasks']
            result['prompt_vocab_size'] = model_inputs['prompt_vocab_size']
        if model_inputs['hidden_states_input'] is not None:
            result['hidden_states'] = model_inputs['hidden_states_input']
        if use_lora_plugin:
            result['lora_params'] = LoraParams(
                model_inputs['lora_ranks'],
                model_inputs['lora_weights_pointers'],
                host_context_lengths=model_inputs['host_context_lengths'],
                max_context_length=max_input_len,
                host_request_types=model_inputs['host_request_types'])
        if model_inputs['spec_decoding_params'] is not None:
            result['spec_decoding_params'] = model_inputs[
                'spec_decoding_params']

        return result

    @classmethod
    def quantize(
        cls,
        hf_model_dir: str,
        output_dir: str,
        dtype: str = 'float16',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        *,
        calib_dataset='cnn_dailymail',
        calib_batches=512,
        calib_batch_size=1,
        calib_max_seq_length=512,
        random_seed=1234,
        tokenizer_max_seq_length=2048,
    ):
        """Quantize a Hugging Face model with ModelOpt and export a checkpoint.

        ``mapping`` defaults to ``Mapping()`` (single GPU). ``quant_config``
        defaults to ``QuantConfig()``, which maps to ModelOpt's ``'full_prec'``
        format; previously a ``None`` value crashed with ``AttributeError``.
        """
        if mapping is None:  # single gpu
            mapping = Mapping()
        if quant_config is None:
            quant_config = QuantConfig()
        modelopt_qformat = quant_config.quant_algo_to_modelopt_qformat()
        kv_cache_dtype = quant_config.kv_cache_quant_algo
        assert modelopt_qformat is not None
        from ..quantization import quantize_and_export
        hf_model_dir = str(
            hf_model_dir)  # quantize_and_export has some code can not take Path
        quantize_and_export(
            model_dir=hf_model_dir,
            calib_dataset=calib_dataset,
            dtype=dtype,
            qformat=modelopt_qformat,
            kv_cache_dtype=kv_cache_dtype,
            calib_size=calib_batches,
            batch_size=calib_batch_size,
            calib_max_seq_length=calib_max_seq_length,
            awq_block_size=quant_config.group_size,
            output_dir=output_dir,
            tp_size=mapping.tp_size,
            pp_size=mapping.pp_size,
            seed=random_seed,
            tokenizer_max_seq_length=tokenizer_max_seq_length,
        )


class DecoderModelForCausalLM(PretrainedModel):
    """Causal-LM wrapper: a ``transformer`` trunk followed by an ``lm_head``.

    On the last pipeline rank the network outputs ``logits``. On other ranks
    it outputs ``hidden_states_output``. Without a paged KV cache, each
    layer's present is additionally marked as ``present_key_value_{i}``.
    """

    def __init__(self, config: PretrainedConfig, transformer, lm_head):
        super().__init__(config)
        self.transformer = transformer
        self.lm_head = lm_head
        # muP scaling: logits are divided by this value (default 1.0).
        config.set_if_not_exist('mup_width_multiplier', 1.0)
        self.mup_width_multiplier = config.mup_width_multiplier

    def forward(self,
                input_ids: Tensor,
                position_ids=None,
                use_cache=False,
                last_token_ids=None,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                prompt_embedding_table: Optional[Tensor] = None,
                prompt_tasks: Optional[Tensor] = None,
                prompt_vocab_size: Optional[Tensor] = None,
                lora_params=None,
                spec_decoding_params=None):
        """Run the trunk and, on the last pipeline rank, the LM head.

        Optional inputs are forwarded to ``self.transformer`` only when they
        are not ``None``.

        Returns:
            Last rank:   ``(lm_logits, presents, hidden_states)`` if presents
                         are exported, else ``(lm_logits, hidden_states)``.
            Other ranks: ``(hidden_states, presents)`` if presents are
                         exported, else ``hidden_states``.
            Presents are exported when ``use_cache`` is true and the paged KV
            cache plugin is off.
        """
        trunk_inputs = {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
            'kv_cache_params': kv_cache_params,
            'attention_params': attention_params,
        }
        optional_inputs = {
            'lora_params': lora_params,
            'hidden_states': hidden_states,
            'prompt_embedding_table': prompt_embedding_table,
            'prompt_tasks': prompt_tasks,
            'prompt_vocab_size': prompt_vocab_size,
            'spec_decoding_params': spec_decoding_params,
        }
        trunk_inputs.update({
            key: value
            for key, value in optional_inputs.items() if value is not None
        })

        trunk_output = self.transformer.forward(**trunk_inputs)
        if use_cache:
            hidden_states, presents = trunk_output
        else:
            hidden_states = trunk_output

        is_last_stage = self.config.mapping.is_last_pp_rank()
        if not is_last_stage:
            hidden_states.mark_output('hidden_states_output', self.config.dtype)
        else:
            hidden_states = gather_last_token_logits(
                hidden_states, last_token_ids,
                default_net().plugin_config.remove_input_padding)

            # [batch_size, hidden_size] -> [batch_size, vocab_size]
            lm_logits = self.lm_head(hidden_states)
            if hasattr(self.config, 'output_multiplier_scale'):
                lm_logits *= self.config.output_multiplier_scale
            if self.mup_width_multiplier is not None:
                lm_logits = lm_logits / self.mup_width_multiplier
            lm_logits.mark_output('logits', self.config.logits_dtype)

        if use_cache and not default_net().plugin_config.paged_kv_cache:
            stage_layer_ids = self.config.mapping.pp_layers(
                self.config.num_hidden_layers)
            for layer_id, present in zip(stage_layer_ids, presents):
                present.mark_output(f'present_key_value_{layer_id}',
                                    self.config.kv_dtype)
            if is_last_stage:
                return (lm_logits, presents, hidden_states)
            return (hidden_states, presents)

        if is_last_stage:
            return lm_logits, hidden_states
        return hidden_states


def fuse_gate_mlp(
    model: PretrainedModel,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
) -> PretrainedModel:
    """Replace every ``GatedMLP`` in ``model`` with an equivalent ``FusedGatedMLP``.

    The ``gate`` and ``fc`` projections are merged into a single ``fused_fc``
    whose weight is the row-wise concatenation ``[gate; fc]``. When the MLP
    has a bias, the fused bias is ``[gate.bias; fc.bias]``. ``proj`` and
    ``inner_layernorm`` are reused from the original MLP.

    Args:
        model: model whose ``GatedMLP`` submodules are replaced in place.
        gemm_swiglu_plugin_dtype: with FP8 quantization, ``'fp8'`` stores the
            fused weight with shape ``(in_features, out_features)`` for the
            GEMM+SwiGLU plugin.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if a ``GatedMLP`` is found and the model's quantization
            algorithm is neither ``None`` nor ``QuantAlgo.FP8``.
    """
    from ..quantization.quantize import fp8_quantize

    quant_algo = model.config.quantization.quant_algo
    for name, mlp, layer in model.named_modules_with_parent():
        if isinstance(mlp, GatedMLP):
            init_params = get_init_params(mlp)
            init_params["inner_layernorm"] = mlp.inner_layernorm is not None
            fused_layer = FusedGatedMLP(**init_params)

            if quant_algo == QuantAlgo.FP8:
                fused_layer = fp8_quantize(fused_layer,
                                           model.config.quantization)

                if isinstance(mlp.dtype, str):
                    dtype = str_dtype_to_torch(mlp.dtype)
                else:
                    dtype = trt_dtype_to_torch(mlp.dtype)

                # dequantize
                gate_weight = numpy_to_torch(
                    mlp.gate.weight.raw_value).to(dtype) * numpy_to_torch(
                        mlp.gate.weights_scaling_factor.raw_value)
                fc_weight = numpy_to_torch(
                    mlp.fc.weight.raw_value).to(dtype) * numpy_to_torch(
                        mlp.fc.weights_scaling_factor.raw_value)

                # concat
                fused_weight = torch.cat([gate_weight, fc_weight], dim=0)

                # quantize with the larger of the two scales so both halves
                # fit; Python max() here presumably assumes per-tensor
                # (single-element) scaling factors — TODO confirm.
                fused_weight_scaling_factor = numpy_to_torch(
                    max(
                        mlp.gate.weights_scaling_factor.raw_value,
                        mlp.fc.weights_scaling_factor.raw_value,
                    ))
                fused_weight = (fused_weight / fused_weight_scaling_factor).to(
                    torch.float8_e4m3fn)

                if gemm_swiglu_plugin_dtype == 'fp8':
                    # gemm_swiglu_plugin needs (k, n) weights
                    # but weights should still be k-major for fp8
                    fused_layer.fused_fc.weight = Parameter(
                        shape=(fused_layer.fused_fc.in_features,
                               fused_layer.fused_fc.out_features),
                        dtype='fp8')
                    fused_layer.fused_fc.weight.value = fused_weight.view(
                        fused_layer.fused_fc.in_features,
                        fused_layer.fused_fc.out_features)
                else:
                    fused_layer.fused_fc.weight.value = fused_weight
                fused_layer.fused_fc.weights_scaling_factor.value = fused_weight_scaling_factor

                fused_layer.fused_fc.activation_scaling_factor.value = max(
                    mlp.gate.activation_scaling_factor.raw_value,
                    mlp.fc.activation_scaling_factor.raw_value,
                )
            elif quant_algo is None:
                fused_layer.fused_fc.weight.value = np.concatenate(
                    [
                        mlp.gate.weight.raw_value,
                        mlp.fc.weight.raw_value,
                    ],
                    axis=0,
                )
            else:
                raise ValueError(f'Unsupported quant algo: {quant_algo}')

            # Bias is fused for both supported modes, in the same [gate; fc]
            # order as the weight. Previously the FP8 path skipped this and
            # left the fused bias unset.
            if mlp.bias:
                fused_layer.fused_fc.bias.value = np.concatenate(
                    [mlp.gate.bias.raw_value, mlp.fc.bias.raw_value], axis=0)

            fused_layer.proj = mlp.proj
            fused_layer.inner_layernorm = mlp.inner_layernorm

            mlp_name = name.rsplit('.', 1)[-1]
            setattr(layer, mlp_name, fused_layer)

    return model


def unfuse_qkv_gemm(model: PretrainedModel) -> PretrainedModel:
    '''Split each self-attention layer's fused QKV GEMM into three GEMMs.

    For every ``Attention`` module that is not cross-attention and still owns
    a ``qkv`` linear layer, three ``ColumnLinear`` layers are built from the
    fused layer's init params and attached as ``layer.q``, ``layer.k`` and
    ``layer.v``. Each is quantized with ``model.config.quantization``. When
    the fused weight/bias is initialized, it is sliced along axis 0 into the
    q, k and v pieces. Every other parameter of the fused layer (anything
    except ``weight``/``bias``) is shared by all three new layers. Finally
    ``layer.qkv`` is set to ``None``.

    The model is modified in place and returned.

    Raises:
        AssertionError: if an eligible attention layer has ``tp_size != 1``.
    '''
    from ..quantization.quantize import quantize

    for _, attn in model.named_modules():
        if not isinstance(attn, Attention) or attn.cross_attention:
            continue
        assert attn.tp_size == 1, "please disable manual tp when enable auto parallel"
        fused = attn.qkv
        if fused is None:
            continue

        init_params = get_init_params(fused, ColumnLinear)
        # The stored values may be None; ColumnLinear takes boolean flags here.
        init_params["bias"] = init_params["bias"] is not None
        init_params["strict_dtype"] = init_params["strict_dtype"] is not None

        q_width = (attn.tp_size * attn.num_attention_heads *
                   attn.attention_head_size)
        kv_width = (attn.tp_size * attn.num_attention_kv_heads *
                    attn.attention_head_size)
        plain_gemms = [
            ColumnLinear(**{
                **init_params,
                "out_features": width,
            }) for width in (q_width, kv_width, kv_width)
        ]
        q, k, v = [
            quantize(gemm, model.config.quantization) for gemm in plain_gemms
        ]

        def _split_qkv(array):
            # Cut the fused tensor at the q/k and k/v boundaries (axis 0).
            return np.split(array, [
                q.out_features,
                q.out_features + k.out_features,
            ])

        if fused.weight.is_inited():
            for gemm, chunk in zip((q, k, v),
                                   _split_qkv(fused.weight.raw_value)):
                gemm.weight.value = chunk
        if fused.bias is not None and fused.bias.is_inited():
            for gemm, chunk in zip((q, k, v), _split_qkv(fused.bias.raw_value)):
                gemm.bias.value = chunk

        # Share every remaining parameter of the fused layer with q, k and v.
        for param_name, param in fused._parameters.items():
            if param_name in ("weight", "bias"):
                continue
            for gemm in (q, k, v):
                setattr(gemm, param_name, param)

        attn.q, attn.k, attn.v = q, k, v
        attn.qkv = None
    return model


def fuse_rg_lru(model: PretrainedModel) -> PretrainedModel:
    '''Replace every ``RgLru`` module with an equivalent ``FusedRgLru``.

    The fused module is constructed from the original's init params. Its
    ``gate`` weight and bias are the concatenation, along the last axis, of
    ``input_gate`` followed by ``recurrent_gate``. ``recurrent_param`` is
    copied unchanged.

    Because the pass reads ``raw_value`` from the existing parameters, it is
    meant to run after weights are loaded. The model is modified in place and
    returned.
    '''

    def _concat_last_axis(first, second):
        # Input-gate values come first, recurrent-gate values second.
        return np.concatenate([first.raw_value, second.raw_value], axis=-1)

    for qualified_name, module, owner in model.named_modules_with_parent():
        if not isinstance(module, RgLru):
            continue
        fused = FusedRgLru(**get_init_params(module))
        fused.gate.weight.value = _concat_last_axis(
            module.input_gate.weight, module.recurrent_gate.weight)
        fused.gate.bias.value = _concat_last_axis(module.input_gate.bias,
                                                  module.recurrent_gate.bias)
        fused.recurrent_param.value = module.recurrent_param.raw_value
        attr_name = qualified_name.rsplit('.', 1)[-1]
        setattr(owner, attr_name, fused)
    return model


def set_prompt_tuning(model: PretrainedModel) -> PretrainedModel:
    '''Swap the model's ``vocab_embedding`` for a ``PromptTuningEmbedding`` in place.

    Every module whose last name component is ``vocab_embedding`` and which is
    an ``Embedding`` is replaced on its parent by a ``PromptTuningEmbedding``.
    The replacement is built from the same init params and takes over the
    original embedding weight. The changed model is returned.

    Pre-conditions: vocab_embedding exists (not enforced; if none is found,
    the model is returned unchanged).
    Post-conditions: isinstance(vocab_embedding, PromptTuningEmbedding)
    '''
    for qualified_name, module, owner in model.named_modules_with_parent():
        is_vocab_embedding = qualified_name.rsplit('.', 1)[-1] == "vocab_embedding"
        if not (is_vocab_embedding and isinstance(module, Embedding)):
            continue
        replacement = PromptTuningEmbedding(**get_init_params(module))
        replacement.weight.value = module.weight.raw_value
        owner.vocab_embedding = replacement
    return model


def add_lora(model: PretrainedModel,
             max_lora_rank: Optional[int]) -> PretrainedModel:
    '''Attach ``Lora`` modules to the model's attention, linear and fused-MLP layers.

    - ``Attention``/``BertAttention`` get ``qkv_lora``, with one output per
      q, k and v projection.
    - ``Linear``/``RowLinear`` get ``lora``, with a single output.
    - ``FusedGatedMLP`` gets ``lora``, with two outputs of the per-rank FFN
      width (``ffn_hidden_size // tp_size``).

    Args:
        model: model to modify in place.
        max_lora_rank: value passed as ``max_low_rank``. When None, each layer
            uses the smallest of its own input/output sizes instead.

    Returns:
        The same model, with LoRA modules attached.
    '''
    for _, module in model.named_modules():
        # Reset for every module; a bound derived for one layer never leaks
        # into the next.
        rank = max_lora_rank

        if isinstance(module, (Attention, BertAttention)):
            q_width = module.num_attention_heads * module.attention_head_size
            kv_width = module.num_attention_kv_heads * module.attention_head_size
            if rank is None:
                rank = min(module.hidden_size, q_width, kv_width)
            module.qkv_lora = Lora(
                in_hidden_size=module.hidden_size,
                out_hidden_sizes=[q_width, kv_width, kv_width],
                max_low_rank=rank,
            )

        if isinstance(module, (Linear, RowLinear)):
            if rank is None:
                rank = min(module.in_features, module.out_features)
            module.lora = Lora(
                in_hidden_size=module.in_features,
                out_hidden_sizes=[module.out_features],
                max_low_rank=rank,
            )

        if isinstance(module, FusedGatedMLP):
            local_ffn_width = module.ffn_hidden_size // module.tp_size
            if rank is None:
                rank = min(module.hidden_size, local_ffn_width)
            module.lora = Lora(
                in_hidden_size=module.hidden_size,
                out_hidden_sizes=[local_ffn_width, local_ffn_width],
                max_low_rank=rank,
            )
    return model


def to_ootb_moe(model: PretrainedModel) -> PretrainedModel:
    '''Replace every ``MOE`` module with its ``MoeOOTB`` counterpart.

    Each MoE layer is converted with ``layer.to(MoeOOTB, model.config)``, and
    the result is installed on the same parent under the same attribute name.
    The model is modified in place and returned.
    '''
    for qualified_name, module, owner in model.named_modules_with_parent():
        if not isinstance(module, MOE):
            continue
        attr_name = qualified_name.rsplit('.', 1)[-1]
        setattr(owner, attr_name, module.to(MoeOOTB, model.config))
    return model


def parallelize_embedding(model: PretrainedModel) -> PretrainedModel:
    '''Rebuild every non-parallel ``Embedding`` as a tensor-parallel one.

    An ``Embedding`` whose ``tp_group`` is None is re-created with the same
    class and init params. The tensor-parallel settings are overridden with
    ``tp_group``/``tp_size``/``tp_rank`` from ``model.config.mapping``, and
    ``sharding_dim`` comes from ``model.config.embedding_sharding_dim``.

    The new module is built from init params only; existing weight values are
    not copied. This is consistent with ``optimize_model`` listing the pass as
    a before-weight-loading step. The model is modified in place and returned.
    '''
    for qualified_name, module, owner in model.named_modules_with_parent():
        if not isinstance(module, Embedding) or module.tp_group is not None:
            continue
        mapping = model.config.mapping
        sharded_params = {
            **get_init_params(module),
            "tp_group": mapping.tp_group,
            "tp_size": mapping.tp_size,
            "tp_rank": mapping.tp_rank,
            "sharding_dim": model.config.embedding_sharding_dim,
        }
        attr_name = qualified_name.rsplit('.', 1)[-1]
        setattr(owner, attr_name, type(module)(**sharded_params))
    return model


def share_embedding(model: PretrainedModel) -> PretrainedModel:
    '''Tie ``lm_head.weight`` to ``vocab_embedding.weight``.

    Modules are scanned in ``named_modules()`` order. Those whose last name
    component is ``lm_head`` or ``vocab_embedding`` are recorded, and the scan
    stops as soon as both have been seen. If either is missing, the model is
    returned unchanged.

    When the embedding also carries a non-None ``per_token_scale``, that scale
    is assigned to ``lm_head.per_channel_scale``. The model is modified in
    place and returned.
    '''
    found = {"lm_head": None, "vocab_embedding": None}
    for qualified_name, module in model.named_modules():
        short_name = qualified_name.rsplit('.', 1)[-1]
        if short_name in found:
            found[short_name] = module
        if all(entry is not None for entry in found.values()):
            break

    lm_head = found["lm_head"]
    vocab_embedding = found["vocab_embedding"]
    if lm_head is None or vocab_embedding is None:
        return model

    lm_head.weight = vocab_embedding.weight
    token_scale = getattr(vocab_embedding, "per_token_scale", None)
    if token_scale is not None:
        lm_head.per_channel_scale = token_scale
    return model


def set_fp8_context_fhma(model: PretrainedModel) -> PretrainedModel:
    '''Give every ``Attention`` layer an ``attention_output_orig_quant_scale``.

    The value is the reciprocal of the layer's
    ``dense.activation_scaling_factor``, cast to float32 and wrapped in a
    ``Parameter``. The "fhma" spelling of this function's name is kept for
    existing callers; ``optimize_model`` enables it via
    ``use_fp8_context_fmha``. The model is modified in place and returned.
    '''
    for _, module in model.named_modules():
        if not isinstance(module, Attention):
            continue
        act_scale = module.dense.activation_scaling_factor.raw_value
        # Dividing the one-element list relies on numpy broadcasting
        # (raw_value is presumably an ndarray). The result is therefore an
        # array rather than a Python float, so ``astype`` is available.
        reciprocal = [1.0] / act_scale
        module.attention_output_orig_quant_scale = Parameter(
            value=reciprocal.astype(np.float32))
    return model


def optimize_model(
    model: PretrainedModel,
    use_parallel_embedding: bool = False,
    share_embedding_table: bool = False,
    use_ootb_moe: bool = False,
    use_fused_mlp: bool = False,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
    use_fused_rg_lru: bool = False,
    use_unfused_qkv_gemm: bool = False,
    use_prompt_tuning: bool = False,
    use_lora: bool = False,
    max_lora_rank: Optional[int] = None,
    use_fp8_context_fmha: bool = False,
) -> PretrainedModel:
    """
    Apply the enabled optimization passes to ``model`` and return the result.

    Some passes depend on others having run first. The enabled passes are
    therefore always applied in the fixed order of the parameters below,
    whichever subset is switched on.

    Args:
        model: model to transform; the passes mutate it in place.
        use_parallel_embedding: run ``parallelize_embedding``.
        share_embedding_table: run ``share_embedding``.
        use_ootb_moe: run ``to_ootb_moe``.
        use_fused_mlp: run ``fuse_gate_mlp``.
        gemm_swiglu_plugin_dtype: forwarded to ``fuse_gate_mlp``.
        use_fused_rg_lru: run ``fuse_rg_lru``.
        use_unfused_qkv_gemm: run ``unfuse_qkv_gemm``.
        use_prompt_tuning: run ``set_prompt_tuning``.
        use_lora: run ``add_lora``.
        max_lora_rank: forwarded to ``add_lora``.
        use_fp8_context_fmha: run ``set_fp8_context_fhma``.

    Returns:
        The model returned by the last pass that ran, or ``model`` itself if
        no pass was enabled.
    """
    pipeline = (
        # Passes intended to run before weight loading.
        (use_parallel_embedding, parallelize_embedding),
        (share_embedding_table, share_embedding),
        # Passes intended to run after weight loading.
        (use_ootb_moe, to_ootb_moe),
        (use_fused_mlp,
         lambda m: fuse_gate_mlp(m, gemm_swiglu_plugin_dtype)),
        (use_fused_rg_lru, fuse_rg_lru),
        (use_unfused_qkv_gemm, unfuse_qkv_gemm),
        (use_prompt_tuning, set_prompt_tuning),
        (use_lora, lambda m: add_lora(m, max_lora_rank)),
        (use_fp8_context_fmha, set_fp8_context_fhma),
    )
    for enabled, apply_pass in pipeline:
        if enabled:
            model = apply_pass(model)
    return model


def preprocess_weights(weights: Dict[str, torch.Tensor],
                       model_config: PretrainedConfig,
                       from_pruned=False) -> Dict[str, torch.Tensor]:
    quant_algo = model_config.quantization.quant_algo
    kv_cache_quant_algo = model_config.quantization.kv_cache_quant_algo

    # INT4_AWQ
    if quant_algo == QuantAlgo.W4A8_AWQ or quant_algo == QuantAlgo.W4A16_AWQ:
        preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm
        if quant_algo == QuantAlgo.W4A8_AWQ:
            activation_type = torch.float8_e4m3fn
        elif quant_algo == QuantAlgo.W4A16_AWQ:
            activation_type = torch.float16
        for name, param in weights.items():
            if from_pruned and param.numel() == 0:
                continue
            if name.endswith('weight') and param.dtype == torch.int8:
                dtype = torch.float16
                if model_config.dtype == "bfloat16":
                    dtype = torch.bfloat16
                weights[name] = preprocessor(param.T.contiguous(),
                                             torch.quint4x2,
                                             activation_type).view(dtype)
            if name.endswith('weights_scaling_factor'):
                weights[name] = param.T.contiguous().to(
                    str_dtype_to_torch(model_config.dtype))
            if name.endswith('prequant_scaling_factor'):
                weights[name] = param.reshape(1, -1)
            if model_config.mapping.tp_rank > 0:
                if name.endswith('attention.dense.bias') or name.endswith(
                        'mlp.proj.bias'):
                    weights[name] = torch.zeros_like(param)

        if quant_algo == QuantAlgo.W4A8_AWQ:
            for name in list(weights):
                if name.endswith('weights_scaling_factor'):
                    activation_scaling_factor = weights.pop(
                        name.replace('weights_scaling_factor',
                                     'activation_scaling_factor'))
                    weights_scaling_factor_2 = weights.pop(
                        name.replace('weights_scaling_factor',
                                     'weights_scaling_factor_2'))
                    weights[name] /= weights_scaling_factor_2
                    weights[name.replace(
                        'weights_scaling_factor',
                        'prequant_scaling_factor')] /= activation_scaling_factor
                    weights[name.replace(
                        'weights_scaling_factor', 'alpha'
                    )] = activation_scaling_factor * weights_scaling_factor_2

    # FP8
    elif quant_algo == QuantAlgo.FP8:
        for name, param in weights.items():
            if name.endswith('weight') and param.dtype == torch.int8:
                weights[name] = param.view(torch.float8_e4m3fn)
        # lm_head is not quantized to FP8
        if "lm_head.weight" in weights:
            assert weights['lm_head.weight'].dtype == str_dtype_to_torch(
                model_config.dtype)
            weights.pop('lm_head.weights_scaling_factor', None)
            weights.pop('lm_head.activation_scaling_factor', None)

    elif quant_algo in [QuantAlgo.W4A16, QuantAlgo.W8A16]:
        weights = weight_only_quantize_dict(weights=weights,
                                            quant_algo=quant_algo,
                                            plugin=True)

    # FP8 kv_cache_scaling_factor is always 1.0
    if kv_cache_quant_algo == QuantAlgo.FP8:
        for name, param in weights.items():
            if name.endswith('kv_cache_scaling_factor'):
                weights[name] = torch.tensor([1.0], dtype=torch.float32)

    # Parallel block rowlinear should not have duplicate bias.
    elif model_config.architecture == 'GPTJForCausalLM':
        if model_config.mapping.tp_rank > 0:
            for name, param in weights.items():
                if 'attention.dense.bias' in name or 'mlp.proj.bias' in name:
                    weights[name] = torch.zeros_like(param)

    # For share_embedding_table
    if model_config.share_embedding_table:
        if "lm_head.weight" in weights and "transformer.vocab_embedding.weight" in weights:
            if (weights["lm_head.weight"] -
                    weights["transformer.vocab_embedding.weight"]).any():
                logger.warning(
                    "lm_head.weight and transformer.vocab_embedding.weight are not identical, "
                    "share_embedding_table cannot be enabled; setting share_embedding_table=False."
                )
                model_config.share_embedding_table = False
            else:
                weights.pop("lm_head.weight")
